Skip to content

make getform work (without DirichletBCs) for DOLFINx#200

Draft
jorgensd wants to merge 30 commits into
firedrakeproject:masterfrom
jorgensd:dokken/getForm
Draft

make getform work (without DirichletBCs) for DOLFINx#200
jorgensd wants to merge 30 commits into
firedrakeproject:masterfrom
jorgensd:dokken/getForm

Conversation

@jorgensd
Copy link
Copy Markdown
Contributor

@jorgensd jorgensd commented Mar 9, 2026

Minimal example:

from mpi4py import MPI
import dolfinx
import ufl
from irksome import GaussLegendre, getForm, Dt, MeshConstant
from irksome.tools import get_stage_space
from ufl import pi, atan, div, grad, inner, dx

butcher_tableau = GaussLegendre(2)
N = 64

x0 = 0.0
x1 = 10.0
y0 = 0.0
y1 = 10.0

msh = dolfinx.mesh.create_rectangle(MPI.COMM_WORLD, [[x0,y0],[x1,y1]],[N, N])
V = dolfinx.fem.functionspace(msh, ("Lagrange", 1))
x, y = ufl.SpatialCoordinate(msh)

MC = MeshConstant(msh, backend="dolfinx")
dt = MC.Constant(10 / N)
t = MC.Constant(0.0)

Constant = lambda val: dolfinx.fem.Constant(msh, val)
S = Constant(2.0)
C = Constant(1000.0)


B = (x-Constant(x0))*(x-Constant(x1))*(y-Constant(y0))*(y-Constant(y1))/C
R = (x * x + y * y) ** 0.5
uexact = B * atan(t)*(pi / 2.0 - atan(S * (R - t)))

rhs = Dt(uexact) - div(grad(uexact))

u = dolfinx.fem.Function(V)
u.interpolate(dolfinx.fem.Expression(uexact, V.element.interpolation_points))

v = ufl.TestFunction(V)
F = inner(Dt(u), v)*dx + inner(grad(u), grad(v))*dx - inner(rhs, v)*dx

bc = []
# bc = DirichletBC(V, 0, "on_boundary")

# Get the function space for the stage-coupled problem and a function to hold the stages we're computing::

Vbig = get_stage_space(V, butcher_tableau.num_stages, backend="dolfinx")
k = dolfinx.fem.Function(Vbig)

# Get the variational form and bcs for the stage-coupled variational problem::

Fnew, bcnew = getForm(F, butcher_tableau, t, dt, u, k, bcs=bc, backend="dolfinx")

Comment thread irksome/stage_derivative.py
Comment thread irksome/stage_derivative.py Outdated
Comment thread irksome/form_manipulation.py Outdated
Comment thread irksome/backends/firedrake.py Outdated
Comment thread irksome/backend.py
Comment thread irksome/backend.py Outdated
Comment thread irksome/backends/firedrake.py Outdated
Comment thread irksome/backends/firedrake.py Outdated
Comment thread irksome/bcs.py Outdated
Comment thread irksome/stage_derivative.py Outdated
Comment thread irksome/__init__.py Outdated
Co-authored-by: Copilot <copilot@github.com>
Comment thread irksome/backends/firedrake.py Outdated
Comment thread irksome/backends/firedrake.py Outdated
Comment thread irksome/stage_derivative.py Outdated
Comment thread irksome/stage_derivative.py Outdated
- `bcnew`, a list of :class:`firedrake.DirichletBC` or :class:`EquationBC`
objects to be posed on the stages
"""
backend_cls = get_backend(backend)
Copy link
Copy Markdown
Collaborator

@pbrubeck pbrubeck May 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we try to consistenty set a backend variable name throught different files? In base_time_stepper we repeatedly call self._backend, but here we define backend_cls.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say it depends on wether we are within a class (that takes backend at construction) or a function that takes backend as an argument.
I think using the string for backend to stand-alone functions (and then fetching the backend class), is a good option, and for classes set a self._backend at construction. However, open to alternatives.

assert V == backend_cls.get_function_space(u0)

c = vecconst(butch.c)
c = vecconst(butch.c, backend=backend)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we implement backend_cls.vecconst? The name of this function could also be improved, especially because this function returns ufl.zero for numeric values equal to zero.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I did it in the given way, with backend as input, is that the UFL operations in vecconst is the same for both implementations, i.e.

import numpy as np
from .backend import get_backend, Backend
import ufl


def MeshConstant(msh, backend: str = "firedrake"):
    mc_backend = get_backend(backend)
    return mc_backend.MeshConstant(msh)


def ConstantOrZero(
    x: float | complex,
    MC: Backend.MeshConstant | None = None,
    backend: str = "firedrake",
) -> ufl.core.expr.Expr:
    backend_impl = get_backend(backend)
    const = backend_impl.get_mesh_constant(MC)
    return ufl.zero() if abs(complex(x)) < 1.0e-10 else const(x)


vecconst = np.vectorize(ConstantOrZero)

If aimed to reduce the amount of duplicate code across the modules.
If you think another option is better I am open for it (but it would be nice to not have to have to versions of ufl.zero() if abs(complex(x)) < 1.0e-10 else const(x) as I feel they can easily divergece (especially the floating point tolerance).

Comment thread irksome/stage_derivative.py
Comment thread irksome/stage_derivative.py Outdated
@pbrubeck
Copy link
Copy Markdown
Collaborator

I think this PR should not apply changes that are merely (ruff) formatting. This makes it a bit harder to review.

Comment thread irksome/backends/firedrake.py Outdated
Comment thread irksome/tools.py Outdated
Comment thread irksome/backends/firedrake.py Outdated
Comment on lines +101 to +113
Function = firedrake.Function

DirichletBC = firedrake.DirichletBC

norm = firedrake.norm

assemble = firedrake.assemble

replace = firedrake.replace

derivative = firedrake.derivative
TestFunction = firedrake.TestFunction
TrialFunction = firedrake.TrialFunction
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to define varibales? Can these not be just imports?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They can, I added them here for clarity in case we ever need to do more than just reference them directly. I'm happy to adapt to whatever coding style you would like.

Comment thread irksome/bcs.py
Comment on lines +124 to +131
return type(self)(
V,
g,
sub_domain,
self.bounds,
self.solver_parameters,
backend=self._backend_name,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop these changes

Comment thread irksome/galerkin_stepper.py Outdated
Comment thread irksome/galerkin_stepper.py Outdated
Comment on lines +430 to +432
F_remainder = expand_time_derivatives(
split_form.remainder, t=t, timedep_coeffs=()
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop these changes

onscale_factor=1.2,
safety_factor=0.9,
gamma0_params=None,
backend_cls: str = "firedrake",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only add extra kwarg

Comment on lines +390 to +404
assert butcher_tableau.btilde is not None
super(AdaptiveTimeStepper, self).__init__(
F,
butcher_tableau,
t,
dt,
u0,
bcs=bcs,
appctx=appctx,
solver_parameters=solver_parameters,
bc_type=bc_type,
splitting=splitting,
nullspace=nullspace,
**kwargs,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop these changes

Comment thread irksome/stage_value.py
@@ -1,3 +1,3 @@
# formulate RK methods to solve for stage values rather than the stage derivatives.
import numpy
from firedrake import TestFunction
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Comment thread irksome/labeling.py
@@ -2,6 +2,7 @@
from ufl import BaseForm, Form, FormSum
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop all changes here. The correct fix is to expose LabelledForm.arguments()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you fix that I'll revert this. Currently this is needed to get any work done on the DOLFINx end of things.

Comment thread irksome/base_time_stepper.py Outdated
Co-authored-by: Jørgen Schartum Dokken <dokken92@gmail.com>
Comment thread irksome/base_time_stepper.py Outdated
Co-authored-by: Jørgen Schartum Dokken <dokken92@gmail.com>
Comment thread irksome/base_time_stepper.py Outdated
Comment on lines +247 to +250
self.sample_values = [
sum(ks[j] * vander[j, i] for j in range(len(ks)))
for i in range(num_samples)
]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.sample_values = [
sum(ks[j] * vander[j, i] for j in range(len(ks)))
for i in range(num_samples)
]
self.sample_values = [sum(ks[j] * vander[j, i] for j in range(len(ks)))
for i in range(num_samples)]

Comment thread irksome/bcs.py
field = 0 if len(V) == 1 else bc.function_space_index()
comp = (bc.function_space().component,)
ws = stages.subfunctions[field::len(V)]
ws = stages.subfunctions[field :: len(V)]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ws = stages.subfunctions[field :: len(V)]
ws = stages.subfunctions[field::len(V)]

Comment thread irksome/bcs.py
Comment on lines +124 to +131
return type(self)(
V,
g,
sub_domain,
self.bounds,
self.solver_parameters,
backend=self._backend_name,
)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return type(self)(
V,
g,
sub_domain,
self.bounds,
self.solver_parameters,
backend=self._backend_name,
)
return type(self)(V, g, sub_domain, self.bounds, self.solver_parameters)

Comment thread irksome/stage_derivative.py Outdated
Comment thread irksome/stage_derivative.py Outdated
jorgensd and others added 3 commits May 18, 2026 13:12
Co-authored-by: Jørgen Schartum Dokken <dokken92@gmail.com>
problem: dolfinx.fem.petsc.LinearProblem | dolfinx.fem.petsc.NonlinearProblem,
**kwargs,
):
"""Create a linear variational solver that uses PETSc KSP."""
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Create a linear variational solver that uses PETSc KSP."""
"""Create a variational solver that uses PETSc SNES or KSP."""


from petsc4py import PETSc
from .tools import AI, getNullspace, flatten_dats, split_stages
from .labeling import as_form
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from .labeling import as_form
try:
from .labeling import as_form
except ImportError:
as_form = lambda x: x

Comment on lines +243 to +246
self.sample_values = [
sum(ks[j] * vander[j, i] for j in range(len(ks)))
for i in range(num_samples)
]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.sample_values = [
sum(ks[j] * vander[j, i] for j in range(len(ks)))
for i in range(num_samples)
]
self.sample_values = [sum(ks[j] * vander[j, i] for j in range(len(ks)))
for i in range(num_samples)]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants